# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Sequence
import dataclasses
import functools
import itertools as it
from typing import Callable, TypeVar, Any, Union

import numpy as np

import jax.numpy as jnp
from jax import dtypes
from jax import lax

from jax._src import api
from jax._src import linear_util as lu
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
from jax._src.ad_util import SymbolicZero
from jax._src.api_util import flatten_fun
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import tree_flatten
from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
                           unzip3, weakref_lru_cache, HashableWrapper)

# Register this file with source-info / traceback filtering; presumably this
# hides checkify-internal frames from user-facing tracebacks (TODO confirm).
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# Shadow the builtins with `safe_map`/`safe_zip` from jax._src.util (per their
# names, presumably length-checking variants); the builtins remain available
# as `unsafe_map` / `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

# Type aliases used in annotations throughout this module.
Bool = Union[bool, Array]
Int = Union[int, Array]
ErrorCategory = type['JaxException']
Payload = list[Union[np.ndarray, Array]]
PyTreeDef = jtu.PyTreeDef
Out = TypeVar('Out')

## Utils

def popattr(obj, attrname):
  """Delete attribute `attrname` from `obj` and return the value it held.

  Raises:
    AttributeError: if `obj` has no attribute `attrname`.
  """
  removed = getattr(obj, attrname)
  delattr(obj, attrname)
  return removed

def setnewattr(obj, name, val):
  """Set `obj.<name> = val`, asserting the attribute does not exist yet.

  Raises:
    AssertionError: if `obj` already has an attribute called `name`.
  """
  assert not hasattr(obj, name)
  setattr(obj, name, val)

# Concrete errors

class JaxException(Exception):
  """Python exception which can contain an error message with JAX run-time info.

  Every subclass is registered as a pytree node class on definition, so
  instances can be flattened into (array leaves, static metadata) and rebuilt.
  The base class carries no leaves; only `traceback_info` is static metadata.
  """

  def __init__(self, traceback_info):
    self.traceback_info = traceback_info
    # TODO(lenamartens): re-enable tracebacks when they don't leak tracers.
    # self.with_traceback(self.traceback_info)

  def __init_subclass__(cls):
    # Make each concrete error type usable with the pytree utilities.
    jtu.register_pytree_node_class(cls)

  def tree_flatten(self):
    no_leaves = []
    return no_leaves, self.traceback_info

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    del payload  # The base exception has no array payload to restore.
    return cls(metadata)

  def get_effect_type(self) -> ErrorEffect:
    """Return the ErrorEffect describing this error; subclasses override."""
    raise NotImplementedError


@functools.total_ordering
@dataclasses.dataclass(eq=True, frozen=True)
class ErrorEffect(effects.Effect):
  """Effect identifying an error type and the shapes/dtypes of its payload.

  A total order is defined so that sets of effects can be sorted.
  """
  error_type: type[JaxException]
  shape_dtypes: tuple[api.ShapeDtypeStruct, ...]

  def __lt__(self, other: ErrorEffect):
    def sort_key(effect):
      # dtypes are not mutually comparable, so compare their string forms.
      shapes_and_dtypes = tuple(
          (sd.shape, str(sd.dtype)) for sd in effect.shape_dtypes)
      return str(effect.error_type), shapes_and_dtypes
    return sort_key(self) < sort_key(other)

# Error effects may appear inside control flow and must survive lowering.
effects.control_flow_allowed_effects.add_type(ErrorEffect)
effects.lowerable_effects.add_type(ErrorEffect)

class DivisionByZeroError(JaxException):
  """Error recorded when a checked division has a zero divisor."""

  def get_effect_type(self):
    # No run-time payload is needed to describe this error.
    no_payload = ()
    return ErrorEffect(DivisionByZeroError, no_payload)

  def __str__(self):
    return 'division by zero'

class NaNError(JaxException):
  """Error recorded when a primitive's output contains a NaN.

  Both the traceback and the primitive name are static pytree metadata; the
  error carries no array leaves.
  """

  def __init__(self, traceback_info, primitive_name):
    super().__init__(traceback_info)
    self.prim = primitive_name

  def tree_flatten(self):
    static_data = (self.traceback_info, self.prim)
    return [], static_data

  @classmethod
  def tree_unflatten(cls, metadata, _):
    return cls(*metadata)

  def get_effect_type(self):
    no_payload = ()
    return ErrorEffect(NaNError, no_payload)

  def __str__(self):
    return f'nan generated by primitive: {self.prim}.'

class OOBError(JaxException):
  """Error recorded for out-of-bounds indexing.

  The payload is a length-3 int32 array holding the offending index, the
  operand axis it indexes, and that axis's size (see `get_effect_type`). It is
  the only pytree leaf; traceback, primitive name and operand shape are static.
  """

  def __init__(self, traceback_info, primitive_name, operand_shape, payload):
    super().__init__(traceback_info)
    self.prim = primitive_name
    self.operand_shape = operand_shape
    self._payload = payload

  def tree_flatten(self):
    leaves = [self._payload]
    static_data = (self.traceback_info, self.prim, self.operand_shape)
    return leaves, static_data

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    return cls(*metadata, payload[0])

  def __str__(self):
    bad_index = self._payload[0]
    axis = self._payload[1]
    axis_size = self._payload[2]
    return (f'out-of-bounds indexing for array of '
            f'shape {self.operand_shape}: '
            f'index {bad_index} is out of bounds for axis '
            f'{axis} with size {axis_size}. ')

  def get_effect_type(self):
    payload_spec = api.ShapeDtypeStruct((3,), jnp.int32)
    return ErrorEffect(OOBError, (payload_spec,))

class FailedCheckError(JaxException):
  """Error recorded when a `check` fails; formats `fmt_string` with its args.

  The positional and keyword format arguments are the pytree leaves, while
  traceback and format string are static metadata.
  """

  def __init__(self, traceback_info, fmt_string, *a, **k):
    super().__init__(traceback_info)
    self.fmt_string = fmt_string
    self.args = a
    self.kwargs = k

  def tree_flatten(self):
    leaves = (self.args, self.kwargs)
    static_data = (self.traceback_info, self.fmt_string)
    return leaves, static_data

  @classmethod
  def tree_unflatten(cls, metadata, payload):
    fmt_args, fmt_kwargs = payload
    return cls(*metadata, *fmt_args, **fmt_kwargs)

  def __str__(self):
    message = self.fmt_string.format(*self.args, **self.kwargs)
    return message + ' (`check` failed)'

  def get_effect_type(self):
    leaves = jtu.tree_leaves((self.args, self.kwargs))
    payload_specs = tuple(
        api.ShapeDtypeStruct(leaf.shape, leaf.dtype) for leaf in leaves)
    return ErrorEffect(FailedCheckError, payload_specs)

@dataclasses.dataclass
class BatchedError(JaxException):
  """Per-index errors of a batched Error value, keyed by mapped index."""
  error_mapping: dict[tuple[int, ...], JaxException]

  def __post_init__(self):
    # The whole batch reports the traceback of the first recorded error.
    first_error = [*self.error_mapping.values()][0]
    super().__init__(first_error.traceback_info)

  def __str__(self):
    lines = []
    for idx, e in self.error_mapping.items():
      lines.append(f'at mapped index {", ".join(map(str, idx))}: {e}')
    return '\n'.join(lines)


# Error Value

@jtu.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Error:
  """Functional error value threaded through a checkified computation.

  For each error kind (an `ErrorEffect`), `_pred` records whether an error
  occurred, `_code` which check produced it, and `_payload` the flattened
  leaves of the corresponding `JaxException`. `_metadata` maps codes back to
  the exception treedefs used to rebuild it. As a pytree, `_pred`, `_code`
  and `_payload` are children while `_metadata` is static aux data (see
  `tree_flatten`).
  """
  _pred: dict[ErrorEffect, Bool]  # per effect: did an error occur?
  _code: dict[ErrorEffect, Int]  # per effect: code of the recorded error.
  _metadata: dict[Int, PyTreeDef]  # mapping of code to JaxException treedef.
  _payload: dict[ErrorEffect, Payload]  # per effect: exception leaves.

  def get(self) -> str | None:
    """Returns error message if error happened, None if no error happened."""
    exp = self.get_exception()
    if exp is not None:
      return str(exp)
    return None

  def get_exception(self) -> JaxException | None:
    """Returns Python exception if error happened, None if no error happened.

    If any predicate is non-scalar (a batched Error), this defers to
    `_get_batched_exception`. Otherwise, among effects whose predicate is set,
    the one with the smallest code is reported; codes come from the
    increasing `next_code` counter, so this is the earliest-allocated check.
    """
    if any(map(np.shape, self._pred.values())):
      return self._get_batched_exception()
    else:
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect]:
          if min_code is None or code < min_code:
            min_code = code
            cur_effect = error_effect

      if cur_effect is not None:
        # Rebuild the exception from its treedef (by code) and payload leaves.
        return tree_unflatten(self._metadata[int(min_code)],  # type: ignore
                              self._payload[cur_effect])
    return None

  def throw(self):
    """Delegates to the module-level `_check_error` with this Error."""
    _check_error(self)

  def __str__(self):
    return f'Error({self.get()})'

  # Internal helpers

  def _get_batched_exception(self) -> BatchedError | None:
    """Per batch index, pick the set error with the smallest code.

    The batch shape is taken from the first predicate. Returns a BatchedError
    mapping each index that has an error to its exception, or None.
    """
    shape = np.shape(list(self._pred.values())[0])
    error_mapping = {}
    for idx in np.ndindex(*shape):
      min_code = None
      cur_effect = None
      for error_effect, code in self._code.items():
        if self._pred[error_effect][idx]:   # type: ignore
          if min_code is None or code[idx] < min_code:
            min_code = code[idx]   # type: ignore
            cur_effect = error_effect

      if cur_effect is not None:
        # Slice this batch index out of every payload leaf.
        payload = tree_map(lambda x, i=idx: x[i], self._payload[cur_effect])
        jax_error = tree_unflatten(self._metadata[int(min_code)], payload)  # type: ignore
        error_mapping[idx] = jax_error
    if error_mapping:
      return BatchedError(error_mapping)
    else:
      return None

  def _update(self, effect_type: ErrorEffect, pred, code, metadata, payload):
    """Return a new Error with `effect_type`'s entries replaced and `metadata` merged."""
    new_errs = {**self._pred, **{effect_type: pred}}  # type: ignore
    new_codes = {**self._code, **{effect_type: code}}  # type: ignore
    new_payload = {**self._payload, **{effect_type: payload}}  # type: ignore
    new_metadata = {**self._metadata, **metadata}
    return Error(new_errs, new_codes, new_metadata, new_payload)

  def _add_placeholder_effects(self, effects: set[ErrorEffect]):
    """Fill out Error with `effects` and jnp.ones arrays of their payloads.

    Effects already present are left untouched. Placeholders have predicate
    False and code -1, so the resulting Error has a fixed structure (e.g. so
    all branches of a cond agree). Note: the `effects` parameter shadows the
    module-level `effects` import within this method.
    """
    new_err = self._pred.copy()
    new_code = self._code.copy()
    new_payload = self._payload.copy()
    for effect in effects:
      if effect not in self._pred.keys():
        new_err[effect] = False
        new_payload[effect] = list(
            tree_map(lambda a: jnp.ones(a.shape, a.dtype), effect.shape_dtypes))
        # The error value associated with this effect will never become True, so
        # we don't need to set a meaningful code.
        new_code[effect] = -1
    return Error(new_err, new_code, self._metadata, new_payload)

  def _replace(self, *args, **kwargs):
    """Return a copy with the given fields replaced (dataclasses.replace)."""
    return dataclasses.replace(self, *args, **kwargs)

  # PyTree methods

  def tree_flatten(self):
    # Children: the pred/code/payload dicts. Aux data: `_metadata` (static),
    # so Errors with different metadata have different treedefs.
    return ((self._pred, self._code, self._payload), (self._metadata))

  @classmethod
  def tree_unflatten(cls, metadata, data):
    pred, code, payload = data
    return cls(pred, code, metadata, payload)

init_error = Error({}, {}, {}, {})  # value used as initial (empty) error.
# Source of error codes. Codes increase monotonically and Error.get_exception
# reports the smallest set code, so ordering matters: switching to uuid4 would
# change which error is reported when several are set.
next_code = it.count(1).__next__  # globally unique ids, could be uuid4

def assert_func(error: Error, pred: Bool, new_error: JaxException) -> Error:
  """Record `new_error` in `error` under a fresh code when `pred` is true.

  The exception is flattened into payload leaves plus a treedef; the treedef
  is stored as metadata under the newly allocated code.
  """
  fresh_code = next_code()
  effect = new_error.get_effect_type()
  leaves, treedef = tree_flatten(new_error)
  return update_error(error, pred, fresh_code, {fresh_code: treedef}, leaves,
                      effect)

def update_error(error, pred, code, metadata, payload, effect_type):
  """Merge a new (pred, code, payload) for `effect_type` into `error`.

  An error already recorded for this effect takes precedence: its code and
  payload are kept wherever its predicate is set, otherwise the new ones are
  used. The predicates are OR-ed together.
  """
  prev_pred = error._pred.get(effect_type, False)
  merged_pred = prev_pred | pred
  merged_code = lax.select(prev_pred, error._code.get(effect_type, -1), code)
  prev_payload = error._payload.get(effect_type, None)
  if prev_payload is None:
    merged_payload = payload
  else:
    keep_prev = functools.partial(lax.select, prev_pred)
    merged_payload = tree_map(keep_prev, prev_payload, payload)
  return error._update(effect_type, merged_pred, merged_code, metadata,
                       merged_payload)


## Checkify transformation for plumbing functional error values.

@lu.transformation_with_aux
def _flatten_and_get_error_metadata_thunk(*invals):
  """Wrap a function returning `(Error, out)` so it returns flat leaves.

  The auxiliary output is `(out_tree, error_effects)`: the treedef of
  `(error, out)` and the set of ErrorEffects present in the output error.
  """
  error, out = yield invals, {}
  out_vals, out_tree = jtu.tree_flatten((error, out))
  yield out_vals, (out_tree, set(error._pred.keys()))

def default_checkify_rule(primitive: core.Primitive, error: Error,
                          enabled_errors, *invals: core.Value,
                          **params: Any) -> tuple[Error, Sequence[core.Value]]:
  """Default rule for primitives in `checkify` interpreter.

  Primitives without a `call_jaxpr` param are bound unchanged and `error` is
  passed through. Call and map primitives are re-bound on a checkified
  version of their inner jaxpr, with the flattened error leaves prepended as
  extra inputs; the resulting error is read back from their outputs.
  """
  if 'call_jaxpr' not in params:
    # Default non-HOP case: just call primitive and don't update error.
    return error, primitive.bind(*invals, **params)

  # Code below handles call- and map-primitives, by recursively calling
  # checkify_jaxpr.
  err_vals, err_tree = jtu.tree_flatten(error)
  num_error_vals = len(err_vals)
  if 'donated_invars' in params:
    # The prepended error inputs are never donated.
    params = dict(params, donated_invars=(*[False]*num_error_vals,
                                          *params['donated_invars']))

  # call_jaxpr handling
  call_jaxpr = params.pop('call_jaxpr')
  if isinstance(call_jaxpr, core.ClosedJaxpr):  # handle closed_call_p
    jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    jaxpr, consts = call_jaxpr, ()
  # Wrap consts so they can be part of the hashable partial's arguments.
  consts_ = tuple(HashableWrapper(c) for c in consts)
  partial_checkify = lu.hashable_partial(lu.wrap_init(
      checkify_jaxpr_flat_hashable), jaxpr, consts_, enabled_errors, err_tree)
  # `metadata()` is only read after `bind` (and lazily inside
  # `out_axes_thunk`); presumably it is populated when the wrapped function
  # is traced.
  partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)

  # map-specific params handling.
  if isinstance(primitive, core.MapPrimitive):
    # Update `in_axes` and `out_axes_thunk` params for map primitive.
    out_val_axes = params.pop('out_axes')

    # Error outputs are mapped along axis 0; the original outputs keep their
    # axes. The number of error outputs is only known after tracing.
    @as_hashable_function(closure=out_val_axes)
    def out_axes_thunk():
      out_err_num = metadata()[0].num_leaves - len(out_val_axes)
      return (*(0,)*out_err_num, *out_val_axes)

    # Error inputs are not mapped (in_axes None).
    params = dict(params,
                  in_axes=(*(None,)*num_error_vals, *params['in_axes']),
                  out_axes_thunk=out_axes_thunk)

  all_vals = primitive.bind(partial_checkify, *err_vals, *invals, **params)

  out_tree, _ = metadata()
  error, out_vals = tree_unflatten(out_tree, all_vals)
  if isinstance(primitive, core.MapPrimitive):
    # Collapse the mapped (axis-0) error into a single error value.
    error = _reduce_any_error(error)
  return error, out_vals

def get_shaped_aval(val):
  """Return the abstract value of `val`, raised to a shaped aval."""
  aval = core.get_aval(val)
  return core.raise_to_shaped(aval)

def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                   error: Error, *args) -> tuple[Error, list[core.Value]]:
  """Interpret a closed jaxpr, threading `error` through it.

  Returns the updated Error and the jaxpr's output values.
  """
  flat_error, error_treedef = jtu.tree_flatten(error)
  inner, consts = jaxpr.jaxpr, jaxpr.consts
  return checkify_jaxpr_flat(inner, consts, enabled_errors, error_treedef,
                             *flat_error, *args)

def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
                        enabled_errors, err_tree: PyTreeDef,
                        *args: core.Value) -> tuple[Error, list[Any]]:
  """Evaluate `jaxpr` equation by equation, threading an Error through it.

  Args:
    jaxpr: the open jaxpr to interpret.
    consts: values bound to `jaxpr.constvars`.
    enabled_errors: error types to check; forwarded to every checkify rule.
    err_tree: treedef of the incoming Error.
    *args: the flattened error leaves followed by the jaxpr's inputs.

  Returns:
    The final Error and the values of `jaxpr.outvars`.
  """
  env: dict[core.Var, Any] = {}
  err_vals, in_args = split_list(args, [err_tree.num_leaves])
  error = jtu.tree_unflatten(err_tree, err_vals)

  # Used below to drop env entries once no later equation needs them.
  last_used = core.last_used(jaxpr)

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  map(write_env, jaxpr.constvars, consts)
  map(write_env, jaxpr.invars, in_args)

  # interpreter loop
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    # Primitives without a registered rule use the default rule.
    checkify_rule = error_checks.get(
        eqn.primitive, functools.partial(default_checkify_rule, eqn.primitive))
    # Run the rule under this equation's source info; presumably this is
    # what makes `get_traceback()` inside rules point at user code.
    name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
    with source_info_util.user_context(eqn.source_info.traceback,
                                       name_stack=name_stack):
      error, outvals = checkify_rule(error, enabled_errors,
                                     *invals, **eqn.params)
    if eqn.primitive.multiple_results:
      map(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
    core.clean_up_dead_vars(eqn, env, last_used)

  return error, map(read_env, jaxpr.outvars)

def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
                                 err_tree, *args):
  """Unwrap HashableWrapper consts, then delegate to checkify_jaxpr_flat."""
  unwrapped_consts = tuple(wrapper.x for wrapper in hashable_consts)
  return checkify_jaxpr_flat(jaxpr, unwrapped_consts, enabled_errors,
                             err_tree, *args)

@lu.transformation_with_aux
def flatten_fun_output(*args):
  """Flatten the wrapped function's output pytree.

  Returns the output leaves, with the output treedef as auxiliary data.
  """
  ans = yield args, {}
  yield tree_flatten(ans)


def _reduce_any_error(error: Error):
  """Collapse an Error mapped along its leading axis into a single Error.

  For each effect, the entry at one index is selected: the last position of
  `argsort` over the predicates, which is an index with a set predicate if
  any exists. Metadata is carried over unchanged.
  """
  reduced = init_error
  for effect, preds in error._pred.items():
    codes = error._code[effect]
    payloads = error._payload[effect]
    chosen = jnp.argsort(preds)[-1]
    pred, code, payload = tree_map(lambda leaf, i=chosen: leaf[i],
                                   (preds, codes, payloads))
    reduced = reduced._update(effect, pred, code, {}, payload)
  return reduced._replace(_metadata=error._metadata)

## check_p primitive

# The `check` primitive. Its operands are the flattened leaves of an Error
# (rebuilt with the `err_tree` param); it produces no outputs.
check_p = core.Primitive('check')
check_p.multiple_results = True  # zero results


def _pp_check(eqn, context, settings) -> core.pp.Doc:
  """Pretty-print a `check` equation, omitting the `err_tree` param."""
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  name_stack_annotation = (f'[{eqn.source_info.name_stack}]'
                           if settings.name_stack else None)
  # `err_tree` is left out of the printed params.
  trimmed_params = sorted((k, v) for (k, v) in eqn.params.items()
                          if k != "err_tree")
  rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
         core.pp_kv_pairs(trimmed_params, context, settings),
         core.pp.text(" ") + core.pp_vars(eqn.invars, context)]
  # `check` has no outputs, so there is no left-hand side to print.
  return core.pp.concat([core.pp.text("", annotation), *rhs])

core.pp_eqn_rules[check_p] = _pp_check

# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
  """Raised when an executed `check` finds that an error occurred."""

@check_p.def_impl
def check_impl(*args, err_tree, debug):
  """Eager implementation of `check_p`: raise if the Error is set.

  Raises:
    JaxRuntimeError: chained from the underlying JaxException.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  exc = tree_unflatten(err_tree, args).get_exception()
  if exc is None:
    return []
  python_tb = exc.traceback_info.as_python_traceback()
  exc.with_traceback(traceback_util.filter_traceback(python_tb))
  raise JaxRuntimeError(str(exc)) from exc

@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree, debug):
  """Abstract eval for `check_p`: no outputs; effects are the Error's effects."""
  del debug
  error_effects = set(tree_unflatten(err_tree, args)._pred)
  return [], error_effects

# TODO(lenamartens) add in-depth error explanation to link to in module docs.
# Raised when a non-debug `check_p` reaches lowering without being
# functionalized (see check_lowering_rule / check_lowering_rule_unsupported).
functionalization_error = ValueError(
    'Cannot abstractly evaluate a checkify.check which was not'
    ' functionalized. This probably means you tried to stage'
    ' (jit/scan/pmap/...) a `check` without functionalizing it'
    ' through `checkify.checkify`.'
    )

def check_lowering_rule(ctx, *args, err_tree, debug):
  """Lower `check_p` to a side-effecting host callback running `python_err`.

  In debug mode nothing is emitted. Otherwise, lowering requires the
  `xla_runtime_errors` config flag; without it `functionalization_error` is
  raised.
  """
  if debug:
    # NOOP (check will only trigger when discharged)
    return []
  if not config.xla_runtime_errors.value:
    raise functionalization_error

  out_op, _, _ = mlir.emit_python_callback(
      ctx, callback=functools.partial(python_err, err_tree),
      token=None,
      operands=args,
      operand_avals=list(ctx.avals_in),
      result_avals=list(ctx.avals_out),
      has_side_effect=True)
  return out_op

def check_lowering_rule_unsupported(*a, debug, **k):
  """Lowering for platforms without runtime checks: only debug checks lower."""
  del a, k
  if not debug:
    raise functionalization_error
  return []

def python_err(err_tree, *args):
  """Host callback: rebuild the Error from its leaves and pass it to `_check_error`."""
  _check_error(tree_unflatten(err_tree, args))
  return []

# TPU cannot lower `check_p` (only debug checks); CPU and GPU lower it to a
# host callback.
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
                       platform='tpu')
for _platform in ('cpu', 'gpu'):
  mlir.register_lowering(check_p, check_lowering_rule, platform=_platform)
del _platform

def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
  """Batching rule for `check_p`.

  Moves every error leaf's batch dimension to the front, using the size of
  the first mapped leaf (presumably broadcasting unmapped leaves), rebuilds
  the batched Error, and hands it to `_check_error`. `check` has no outputs,
  so there are no output values or batch dims.
  """
  pairs = zip(batched_args, batch_dims)
  axis_size = next(arg.shape[bdim] for arg, bdim in pairs
                   if bdim is not batching.not_mapped)
  leaves = [batching.bdim_at_front(arg, bdim, axis_size)
            for arg, bdim in pairs]
  _check_error(tree_unflatten(err_tree, leaves), debug=debug)
  return [], []
batching.primitive_batchers[check_p] = check_batching_rule

def check_jvp_rule(primals, _, *, err_tree, debug):
  """JVP rule for `check_p`: check the primal values; tangents are ignored.

  `check` has no outputs, so both the primal and tangent results are empty.
  """
  check_p.bind(*primals, err_tree=err_tree, debug=debug)
  primals_out, tangents_out = [], []
  return primals_out, tangents_out
ad.primitive_jvps[check_p] = check_jvp_rule

## checkify rules

# Rule signature: (Error, enabled_errors, *in_vals, **params) -> (Error, out).
# Every rule in this module returns the updated Error first, then the output.
# `enabled_errors` is presumably a FrozenSet[ErrorCategory]; rules only test
# membership in it.
ErrorCheckRule = Callable
# Checkify rule per primitive; primitives without an entry fall back to
# `default_checkify_rule` in `checkify_jaxpr_flat`.
error_checks: dict[core.Primitive, ErrorCheckRule] = {}


def get_traceback():
  """Return the traceback of the current source-info context."""
  current_info = source_info_util.current()
  return current_info.traceback

def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
  """Bind `prim`, then record a NaNError if its output contains a NaN."""
  result = prim.bind(*in_vals, **params)
  return check_nans(prim, error, enabled_errors, result), result

def check_nans(prim, error, enabled_errors, out):
  """Record a NaNError for `prim` if any value in `out` is NaN.

  `out` is a sequence of arrays when `prim.multiple_results` is set, else a
  single array. PRNG key arrays are skipped. When NaNError is not enabled,
  `error` is returned unchanged.
  """
  if NaNError not in enabled_errors:
    return error

  def has_nan(val):
    # Key arrays have no numeric representation to test for NaN.
    if jnp.issubdtype(val.dtype, dtypes.prng_key):
      return False
    return jnp.any(jnp.isnan(val))

  if prim.multiple_results:
    any_nans = jnp.any(jnp.array([has_nan(val) for val in out]))
  else:
    any_nans = has_nan(out)
  return assert_func(error, any_nans, NaNError(get_traceback(), prim.name))


# All primitives which can generate a NaN.
# Each gets `nan_error_check` as its checkify rule below. `div_p` and the
# scatter primitives are not listed here; they have dedicated rules that also
# call into the NaN check.
nan_primitives = [lax.acos_p, lax.acosh_p, lax.add_p, lax.asin_p, lax.asinh_p,
                  lax.atan2_p, lax.atan_p, lax.atanh_p, lax.bessel_i0e_p,
                  lax.bessel_i1e_p, lax.cbrt_p, lax.conv_general_dilated_p,
                  lax.cos_p, lax.cosh_p, lax.cumlogsumexp_p, lax.cummax_p,
                  lax.cummin_p, lax.cumprod_p, lax.cumsum_p, lax.digamma_p,
                  lax.dot_general_p, lax.erf_inv_p, lax.erf_p, lax.erfc_p,
                  lax.exp_p, lax.expm1_p, lax.fft_p, lax.igamma_grad_a_p,
                  lax.igamma_p, lax.igammac_p, lax.integer_pow_p, lax.lgamma_p,
                  lax.linear_solve_p, lax.log1p_p, lax.log_p, lax.logistic_p,
                  lax.mul_p, lax.pad_p, lax.pow_p, lax.psum_p,
                  lax.random_gamma_grad_p, lax.reduce_p, lax.reduce_prod_p,
                  lax.reduce_sum_p, lax.reduce_window_p,
                  lax.reduce_window_sum_p, lax.regularized_incomplete_beta_p,
                  lax.rem_p, lax.rng_uniform_p, lax.rsqrt_p, lax.sin_p,
                  lax.sinh_p, lax.sqrt_p, lax.sub_p, lax.tan_p, lax.tanh_p]

for _prim in nan_primitives:
  error_checks[_prim] = functools.partial(nan_error_check, _prim)


def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, slice_sizes):
  """Checkify rule for `dynamic_slice`: flag start indices that would leave
  the operand (negative, or start + slice size past the dimension size)."""
  out = lax.dynamic_slice_p.bind(operand, *start_indices, slice_sizes=slice_sizes)

  if OOBError not in enabled_errors:
    return error, out

  starts = jnp.array(start_indices)
  too_low = starts < 0
  too_high = starts + np.array(slice_sizes) > np.array(operand.shape)
  oob_mask = too_low | too_high

  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  oob_error = OOBError(get_traceback(), "dynamic_slice", operand.shape, payload)
  return assert_func(error, jnp.any(oob_mask), oob_error), out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check

def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
  """Checkify rule for `dynamic_update_slice`: flag start indices for which
  the update would not fit inside the operand."""
  out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)

  if OOBError not in enabled_errors:
    return error, out

  starts = jnp.array(start_indices)
  too_low = starts < 0
  too_high = starts + np.array(update.shape) > np.array(operand.shape)
  oob_mask = too_low | too_high

  payload = oob_payload(oob_mask, starts, range(operand.ndim), operand.shape)
  oob_error = OOBError(get_traceback(), "dynamic_update_slice", operand.shape, payload)
  return assert_func(error, jnp.any(oob_mask), oob_error), out
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check

def gather_error_check(error, enabled_errors, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Checkify rule for `gather`: bind it and check start indices are in bounds.

  Returns `(error, out)`. When OOBError is enabled, an OOBError is recorded if
  any start index is negative or larger than (dimension size - slice size) for
  the operand dimension it addresses (per `start_index_map`). The last axis of
  `start_indices` is the index vector; leading axes are batch axes.
  """
  out = lax.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  if OOBError not in enabled_errors:
    return error, out

  dnums = dimension_numbers
  # Explicit integer dtype: `np.array(())` is float64, which numpy refuses to
  # use as an index array.
  index_map = np.asarray(dnums.start_index_map, dtype=np.int64)
  if index_map.size == 0:
    # No operand dimension is addressed by `start_indices`, so nothing can be
    # out of bounds (and there is no index to report).
    return error, out

  # compare to OOB masking logic in lax._gather_translation_rule
  operand_dims = np.array(operand.shape)
  num_batch_dims = len(start_indices.shape) - 1

  # Largest valid start index for each addressed operand dimension.
  upper_bound = operand_dims[index_map] - np.array(slice_sizes)[index_map]
  upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
  oob_mask = (start_indices < 0) | (start_indices > upper_bound.astype(start_indices.dtype))

  payload = oob_payload(oob_mask, start_indices, dnums.start_index_map, operand.shape)
  error = assert_func(error, jnp.any(oob_mask), OOBError(get_traceback(), "gather", operand.shape, payload))
  return error, out

def div_error_check(error, enabled_errors, x, y):
  """Checks for division by zero and NaN.

  A DivisionByZeroError is recorded if any element of the divisor `y` equals
  zero; the division itself then goes through the NaN check.
  """
  if DivisionByZeroError in enabled_errors:
    has_zero_divisor = jnp.any(jnp.equal(y, 0))
    zero_div_error = DivisionByZeroError(get_traceback())
    error = assert_func(error, has_zero_divisor, zero_div_error)
  return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check

def oob_payload(oob_mask, indices, dims_map, operand_shape):
  """Build the OOBError payload for the first out-of-bounds index.

  Args:
    oob_mask: boolean array, same shape as `indices`, True where out of bounds.
    indices: the index array that was checked.
    dims_map: iterable mapping each position along the last axis of `indices`
      to the operand axis it indexes (e.g. `range(ndim)` or a start_index_map).
    operand_shape: shape of the indexed operand.

  Returns:
    int32 array `[oob_index, oob_axis, oob_axis_size]` for the first True entry
    of `oob_mask` (in flattened order). If `oob_mask` is empty (nothing can be
    out of bounds, e.g. a 0-d operand), returns zeros of the same shape/dtype
    so the payload still matches OOBError's effect type.
  """
  if oob_mask.size == 0:
    # argmin of an empty array is an error; this payload is never reported
    # because the corresponding predicate (any over an empty mask) is False.
    return jnp.zeros((3,), dtype=jnp.int32)
  # Get first OOB index, axis and axis size so it can be added to the error msg.
  # argmin over `not mask` finds the first True position of `mask`.
  flat_idx = jnp.argmin(jnp.logical_not(oob_mask))
  multi_idx = jnp.unravel_index(flat_idx, indices.shape)
  # Normalize to a tuple so iterables like `range` are handled uniformly.
  axes = jnp.asarray(tuple(dims_map), dtype=jnp.int32)
  oob_axis = axes[multi_idx[-1]]
  oob_axis_size = jnp.asarray(tuple(operand_shape))[oob_axis]
  oob_index = jnp.ravel(indices)[flat_idx]
  payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)
  return payload

def scatter_oob(operand, indices, updates, dnums):
  """Compute the OOB predicate and payload for a scatter's indices.

  Returns `(any_oob, payload)`, where `payload` comes from `oob_payload`.
  An index is out of bounds if it is negative or larger than
  (operand dim size - window size) for the operand dim it scatters into.
  """
  # Ref: see clamping code used in scatter_translation_rule
  # Window size per operand dim: 1 for inserted window dims, otherwise the
  # size of the next update window dim.
  window_sizes = []
  next_window = 0
  for dim in range(len(operand.shape)):
    if dim not in dnums.inserted_window_dims:
      window_sizes.append(updates.shape[dnums.update_window_dims[next_window]])
      next_window += 1
    else:
      window_sizes.append(1)

  max_start = np.array(
      [operand.shape[dim] - window_sizes[dim]
       for dim in dnums.scatter_dims_to_operand_dims],
      np.int64)
  # Keep the bound representable in the index dtype.
  max_start = np.minimum(max_start, np.iinfo(indices.dtype).max)
  # Broadcast along the last (index-vector) axis of `indices`.
  max_start = lax.broadcast_in_dim(max_start, indices.shape,
                                   (len(indices.shape) - 1,))

  oob_mask = jnp.logical_or(
      jnp.less(indices, 0),
      jnp.greater(indices, max_start.astype(indices.dtype)))
  payload = oob_payload(oob_mask, indices,
                        dnums.scatter_dims_to_operand_dims, operand.shape)
  return jnp.any(oob_mask), payload

def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
                        *, update_jaxpr, update_consts, dimension_numbers,
                        indices_are_sorted, unique_indices, mode):
  """Checks if indices are within bounds and update does not generate NaN.

  Returns `(error, out)`. The OOB check runs only when OOBError is enabled;
  the NaN check runs independently of it (`check_nans` itself is a no-op
  unless NaNError is enabled).
  """
  out = prim.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)

  if OOBError in enabled_errors:
    out_of_bounds, payload = scatter_oob(operand, indices, updates, dimension_numbers)
    oob_error = OOBError(get_traceback(), prim.name, operand.shape, payload)
    error = assert_func(error, out_of_bounds, oob_error)
  # Previously an early return when OOBError was disabled also skipped this,
  # silently disabling NaN checks for scatter primitives.
  error = check_nans(prim, error, enabled_errors, out)
  return error, out
error_checks[lax.scatter_p] = functools.partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_min_p)
error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
                                                    lax.scatter_max_p)

# HOP error check rules

@weakref_lru_cache
def jaxpr_to_checkify_jaxpr(
    jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
    *flat_err_and_in_vals) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Trace a checkified version of `jaxpr` (cached via weakref_lru_cache).

  Args:
    jaxpr: the closed jaxpr to checkify.
    enabled_errors: error types to check.
    err_tree: treedef of the incoming Error.
    *flat_err_and_in_vals: abstract values (callers in this module pass
      shaped avals) for the error leaves followed by the jaxpr's inputs.

  Returns:
    The checkified closed jaxpr, the treedef of its `(error, outputs)`
    result, and the set of ErrorEffects its output error can carry.
  """
  checkify_jaxpr_partial = functools.partial(checkify_jaxpr_flat, jaxpr.jaxpr,
                                             jaxpr.consts, enabled_errors,
                                             err_tree)
  fun = lu.wrap_init(checkify_jaxpr_partial)
  fun, metadata = _flatten_and_get_error_metadata_thunk(fun)

  # The `()` pattern requires the last returned element to be empty.
  new_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, flat_err_and_in_vals)
  checked_jaxpr = core.ClosedJaxpr(new_jaxpr, consts)
  # Read only after tracing `fun` above.
  out_tree, error_effects = metadata()
  return checked_jaxpr, out_tree, error_effects

def cond_error_check(error: Error, enabled_errors, index, *ops, branches, linear):
  """Checkify rule for `cond`: thread the Error through every branch.

  All branches must take and return an Error of the same structure, so the
  incoming error is first padded with placeholder entries for every effect
  any branch can produce. The branches are then checkified against that
  merged error, and `cond` is bound with the error leaves prepended.
  """
  # Get the error-effects out of all branches so the cond can be called with
  # a merged error with all these effects.
  err_vals, err_tree = jtu.tree_flatten(error)
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  def get_error_effects_from_jaxpr(jxpr):
    _, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                            *in_avals)
    return effects
  # NOTE: this local `effects` shadows the module-level `effects` import.
  effects = [get_error_effects_from_jaxpr(jxpr) for jxpr in branches]
  merged_error = error._add_placeholder_effects(set().union(*effects))
  err_vals, err_tree = jtu.tree_flatten(merged_error)
  # Error inputs are never linear.
  new_linear = (*[False] * len(err_vals), *linear)

  # Update branch jaxprs to be checkified jaxprs.
  in_avals = map(get_shaped_aval, [*err_vals, *ops])
  new_branches, out_trees, _ = unzip3(
      jaxpr_to_checkify_jaxpr(
          jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)

  err_and_outs = lax.cond_p.bind(
      index, *err_vals, *ops,
      branches=tuple(new_branches), linear=new_linear)

  # we need to merge metadata across out_trees (a tuple)
  # `_metadata` is static pytree aux data, so each branch's out_tree carries
  # only the metadata for the checks traced in that branch.
  err0, out = tree_unflatten(out_trees[0], err_and_outs)
  merged_metadata = err0._metadata
  for tr in out_trees[1:]:
    err, _ = tree_unflatten(tr, err_and_outs)
    merged_metadata = {**merged_metadata, **err._metadata}
  return err0._replace(_metadata=merged_metadata), out
error_checks[lax.cond_p] = cond_error_check

def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                     num_consts, num_carry, linear, unroll, _split_transpose):
  """Checkify rule for `scan`: carry the Error through the loop.

  The error leaves become extra carry values. The body is traced once to
  find the effects it can produce, the incoming error is padded with
  placeholders for them (so in- and out-carried errors have the same type),
  and the body is re-traced against the padded error.
  """

  consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
  # The body sees one slice of each `xs` along axis 0.
  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
  # Query body effects to create a merged error containing all effects (such
  # that in and out carried error are of the same type).
  err_vals, err_tree = jtu.tree_flatten(error)
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  # NOTE: this local `effects` shadows the module-level `effects` import.
  _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                          err_tree, *new_in_aval)

  merged_error = error._add_placeholder_effects(effects)
  err_vals, err_tree = jtu.tree_flatten(merged_error)

  # Create checked-jaxpr, with the needed pre-processing on the inputs.
  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
  checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *new_in_aval)

  # The checked jaxpr takes [errors, consts, carry, xs]; scan expects
  # [consts, carry, xs], so move the consts in front of the error inputs,
  # which then become the leading part of the carry.
  tomove = ([False] * len(err_vals) + [True] * len(consts)
            + [False] * (len(carry) + len(xs)))
  checked_jaxpr = pe.move_binders_to_front(checked_jaxpr_, tomove)
  new_in_flat = [*consts, *err_vals, *carry, *xs]
  new_linear = (*[False] * len(err_vals), *linear)
  err_and_out = lax.scan_p.bind(
      *new_in_flat, reverse=reverse, length=length, jaxpr=checked_jaxpr,
      num_consts=len(consts), num_carry=len(carry)+len(err_vals),
      linear=new_linear, unroll=unroll, _split_transpose=_split_transpose)
  err, out = tree_unflatten(out_tree, err_and_out)
  return err, out

error_checks[lax.scan_p] = scan_error_check

def checkify_while_body_jaxpr(
    cond_jaxpr: core.ClosedJaxpr, body_jaxpr: core.ClosedJaxpr,
    enabled_errors, error: Error,
    c_consts_num: int) -> tuple[core.ClosedJaxpr, PyTreeDef, set[ErrorEffect]]:
  """Checkify a while-loop body that also re-evaluates the loop condition.

  The traced function takes ``(*cond_consts, *body_inputs)``, runs the body,
  and then applies the condition to the body's outputs. The predicate is
  discarded; the call is there so that an error the *next* cond application
  would raise is detected inside the body. Its outputs are the body outputs.

  Args:
    cond_jaxpr: the loop condition; its first ``c_consts_num`` inputs are
      constants.
    body_jaxpr: the loop body.
    enabled_errors: the enabled error categories.
    error: the current error; its flattened leaves become the leading inputs
      of the checked jaxpr.
    c_consts_num: number of constant inputs of ``cond_jaxpr``.

  Returns:
    ``(checked_jaxpr, out_tree, error_effects)`` as produced by
    :func:`jaxpr_to_checkify_jaxpr` for the combined body-then-cond function.
  """
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def body_then_cond(*cond_consts_and_body_ins):
    cond_consts = cond_consts_and_body_ins[:c_consts_num]
    body_ins = cond_consts_and_body_ins[c_consts_num:]
    new_carry = run_body(*body_ins)
    # Only evaluated for the checks it emits: this checks if the next cond
    # application will error.
    run_cond(*cond_consts, *new_carry)
    return new_carry

  cond_const_avals = cond_jaxpr.in_avals[:c_consts_num]
  traced_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(body_then_cond), [*cond_const_avals, *body_jaxpr.in_avals])
  closed = pe.close_jaxpr(traced_jaxpr)
  err_leaves, err_tree = jtu.tree_flatten(error)
  in_avals = [*map(get_shaped_aval, err_leaves), *cond_const_avals,
              *body_jaxpr.in_avals]
  checked_jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
      closed, enabled_errors, err_tree, *in_avals)
  return checked_jaxpr, out_tree, error_effects


@weakref_lru_cache
def ignore_error_output_jaxpr(jaxpr, num_error_vals: int):
  """Constructs a checked jaxpr which does not output its error value."""
  consts = jaxpr.consts
  jaxpr = jaxpr.jaxpr
  new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[num_error_vals:])
  return core.ClosedJaxpr(new_jaxpr, consts)

def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
                           cond_jaxpr, body_nconsts, body_jaxpr):
  """Checkify rule for ``lax.while_p``.

  The error is threaded through the loop as extra carry values. The first
  condition evaluation is checked before the loop; later condition
  evaluations are checked inside the body, which re-evaluates the condition
  on its new carry (see :func:`checkify_while_body_jaxpr`).

  Raises:
    ValueError: if the loop predicate is not a scalar (e.g.
      checkify-of-vmap-of-while), which is not supported.

  Returns:
    A pair ``(error, out)`` with the final error and the loop outputs.
  """
  if cond_jaxpr.out_avals[0].shape:
    # TODO(lenamartens, sharadmv): support batched while.
    raise ValueError('Checkify does not support batched while-loops '
                     '(checkify-of-vmap-of-while). \nHint: if possible, move '
                     'the vmap to the outer level to get '
                     'vmap-of-checkify-of-while.')

  c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
  # Check if the first cond application will error.
  error, _ = checkify_jaxpr(cond_jaxpr, enabled_errors, error, *c_consts, *carry)

  # First pass: only used to collect the error effects the checked body can
  # produce.
  _, _, error_effects = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr,
                                                  enabled_errors, error,
                                                  cond_nconsts)
  # merged error!
  # Placeholders for those effects give the carried error the same structure
  # on loop entry and exit.
  error = error._add_placeholder_effects(error_effects)
  err_vals, err_tree = jtu.tree_flatten(error)
  # Second pass: build the checked body against the merged error.
  checked_body_jaxpr_, body_out_tree, _ = checkify_while_body_jaxpr(
      cond_jaxpr, body_jaxpr, enabled_errors, error, cond_nconsts)
  num_error_vals = len(err_vals)
  # Checked body inputs are (err vals, cond consts, body consts, carry);
  # while_p wants all body consts first, so move both const groups forward.
  to_move = ([False] * num_error_vals + [True] * cond_nconsts
             + [True] * body_nconsts + [False] * len(carry))
  checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

  cond_in_flat = [*err_vals, *c_consts, *carry]
  cond_in_flat = map(get_shaped_aval, cond_in_flat)
  checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                     err_tree, *cond_in_flat)
  # The loop condition may only produce the predicate, so drop the checked
  # cond's leading error outputs. Its checks already run before the loop and
  # inside the body.
  compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
  # Checked cond inputs are (err vals, cond consts, carry); consts go first.
  to_move = [False] * num_error_vals + [True] * cond_nconsts + [False] * len(carry)
  compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move)

  # Operands: cond consts, then body consts (cond consts + body consts, since
  # the checked body also evaluates the cond), then the carry (err vals +
  # original carry).
  new_in_flat = [*c_consts, *c_consts, *b_consts, *err_vals, *carry]
  all_out_vals = lax.while_p.bind(
      *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
      body_nconsts=cond_nconsts+body_nconsts, body_jaxpr=checked_body_jaxpr)
  # body_out_tree will have all the metadata of cond because it executes a cond!
  error, out = tree_unflatten(body_out_tree, all_out_vals)
  return error, out
error_checks[lax.while_p] = while_loop_error_check

def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
                     in_shardings, out_shardings,
                     in_layouts, out_layouts,
                     resource_env, donated_invars, name, inline, keep_unused):
  """Checkify rule for ``pjit_p``.

  The error leaves become extra leading inputs of the checked jaxpr, and its
  extra leading outputs carry the resulting error. All of these extra values
  get an unspecified sharding and a ``None`` layout, and the extra inputs are
  never donated.

  Returns:
    The pair ``(error, out)`` unflattened from the pjit outputs.
  """
  err_leaves, err_tree = jtu.tree_flatten(error)
  operands = [*err_leaves, *vals_in]
  operand_avals = tuple(get_shaped_aval(v) for v in operands)
  checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(
      jaxpr, enabled_errors, err_tree, *operand_avals)

  # Number of error values added in front of the inputs and of the outputs.
  n_err_in = len(err_leaves)
  n_err_out = out_tree.num_leaves - len(out_shardings)

  def prepend(count, filler, params):
    # Pad a per-value pjit param with `count` copies of `filler` in front.
    return (*[filler] * count, *params)

  unspecified = sharding_impls.UNSPECIFIED
  err_and_out = pjit.pjit_p.bind(
      *operands,
      jaxpr=checked_jaxpr,
      in_shardings=prepend(n_err_in, unspecified, in_shardings),
      out_shardings=prepend(n_err_out, unspecified, out_shardings),
      in_layouts=prepend(n_err_in, None, in_layouts),
      out_layouts=prepend(n_err_out, None, out_layouts),
      resource_env=resource_env,
      donated_invars=prepend(n_err_in, False, donated_invars),
      name=name,
      inline=inline,
      keep_unused=keep_unused,
  )
  return tree_unflatten(out_tree, err_and_out)
error_checks[pjit.pjit_p] = pjit_error_check

def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts,
                         jvp_jaxpr_thunk, call_jaxpr, **params):
  """Checkify rule for ``custom_jvp_call_p``.

  The primal function ``call_jaxpr`` is checkified, while the custom JVP rule
  is only lifted (via :func:`lift_jvp`) so that it passes the error values
  and their tangents through.

  Returns:
    A pair ``(out_err, out_vals)``.
  """
  # The types to have in mind are:
  #   jvp : (a -> b) -> (a, T a) -> (b, T b)
  #   checkify : (a -> b) -> a -> Err b
  #   jvp-of-checkify : (a -> b) -> (a, T a) -> (Err b, T (Err b))
  # where because Err is a pytree, we necessarily have T (Err b) = Err' (T b)
  # where the other Err' components are trivial (of float0 dtype).
  # Semantically, we don't add checks to the JVP rule. To check the result of a
  # JVP rule, one must instead use checkify-of-jvp. Thus this implementation
  # just forwards the input error and code (and trivial tangents) to the output.
  err_vals, err_tree = jtu.tree_flatten(in_err)
  # Primal function: the checkified call jaxpr, taking the error leaves as its
  # leading inputs.
  partial_checkify = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, call_jaxpr.jaxpr,
                        call_jaxpr.consts, enabled_errors, err_tree))
  # Presumably captures the output error tree/metadata when the checkified
  # function runs; it is read back through `merge_linear_aux` below.
  partial_checkify, f_metadata = _flatten_and_get_error_metadata_thunk(
      partial_checkify)
  jvp = lift_jvp(err_tree.num_leaves, num_consts, jvp_jaxpr_thunk)
  jvp, jvp_out_tree = flatten_fun_output(jvp)
  # Error values are bound in front of the primitive's original inputs.
  all_outs = custom_derivatives.custom_jvp_call_p.bind(
      partial_checkify, jvp, *err_vals, *in_vals, **params)
  # `fst` selects whose auxiliary output is available: the checkified primal
  # function's (True) or the lifted JVP rule's (False).
  fst, out_metadata = lu.merge_linear_aux(f_metadata, jvp_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    # The lifted JVP rule returns the forwarded error values first.
    err_vals, out_vals = split_list(all_outs, [len(err_vals)])
    # forward input error to output
    out_err = jtu.tree_unflatten(err_tree, err_vals)
  return out_err, out_vals
error_checks[custom_derivatives.custom_jvp_call_p] = custom_jvp_call_rule

# Compared to custom_derivatives.lift_jvp, we're handling the extra inputs and
# outputs that checkify adds (just forwarding the error data's primal and
# tangent components). The jaxpr in jvp_jaxpr_thunk doesn't expect those.
# TODO(mattjj): can we simplify this, or dedup with custom_derivatives.lift_jvp?
# Adding another layer of lu.transformation was tricky, though maybe doable.
def lift_jvp(num_errs, num_consts, jvp_jaxpr_thunk):
  """Lift a custom JVP rule so that it also forwards checkify error values.

  The returned function is called as ``jvp(*primals, *tangents)``. Each half
  is laid out in the order ``custom_jvp_call_rule`` binds
  ``custom_jvp_call_p``: ``[*errs, *consts, *args]``, i.e. the ``num_errs``
  error values first, followed by the primitive's original inputs, whose
  first ``num_consts`` entries are not passed to the user's JVP jaxpr.

  Args:
    num_errs: number of flattened error values.
    num_consts: number of leading constant inputs among the original inputs.
    jvp_jaxpr_thunk: called with per-tangent symbolic-zero flags; returns
      ``(jvp_jaxpr, jvp_consts, out_zeros)``.

  Returns:
    A wrapped function producing
    ``[*primal_errs, *out_primals, *tangent_errs, *out_tangents]``, where the
    error primals/tangents are forwarded unchanged from the inputs.
  """
  @lu.wrap_init
  def jvp(*xs):
    n, ragged = divmod(len(xs), 2)
    assert not ragged
    primals_in, tangents_in = xs[:n], xs[n:]
    # Error values lead each half; the consts that follow them are skipped
    # because the user's JVP jaxpr does not take them.
    primal_errs = primals_in[:num_errs]
    tangent_errs = tangents_in[:num_errs]
    primals = primals_in[num_errs + num_consts:]
    tangents = tangents_in[num_errs + num_consts:]
    zeros = [type(t) is SymbolicZero for t in tangents]
    jvp_jaxpr, jvp_consts, out_zeros = jvp_jaxpr_thunk(*zeros)
    nonzero_tangents = [t for t in tangents if type(t) is not SymbolicZero]
    out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents)
    out_primals, nz_out_tangents = split_list(out, [len(out_zeros)])
    # Re-expand the non-zero output tangents, inserting symbolic zeros where
    # the JVP rule reported a zero output tangent.
    nz_out_tangents_ = iter(nz_out_tangents)
    out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace())
                    if z else next(nz_out_tangents_)
                    for p, z in zip(out_primals, out_zeros)]
    assert next(nz_out_tangents_, None) is None
    return [*primal_errs, *out_primals, *tangent_errs, *out_tangents]
  return jvp

def custom_vjp_call_jaxpr_rule(in_err, enabled_errors, *in_vals, fun_jaxpr,
                               fwd_jaxpr_thunk, num_consts, bwd, out_trees,
                               symbolic_zeros):
  """Checkify rule for ``custom_vjp_call_jaxpr_p``.

  Only the primal function ``fun_jaxpr`` is checkified. The forward rule is
  evaluated without checks (see the TODO below), and the backward rule is
  wrapped to return a ``None`` cotangent for each error input.

  Returns:
    A pair ``(out_err, out_vals)``. When ``merge_linear_aux`` reports that the
    checkified function's metadata is not the one available, ``in_err`` is
    returned unchanged.
  """
  err_vals, err_tree = jtu.tree_flatten(in_err)
  num_errs = err_tree.num_leaves
  # Primal function: the checkified fun_jaxpr, taking the error leaves first.
  checkified_fun = lu.wrap_init(
      functools.partial(checkify_jaxpr_flat, fun_jaxpr.jaxpr,
                        fun_jaxpr.consts, enabled_errors, err_tree))
  checkified_fun, fun_metadata = _flatten_and_get_error_metadata_thunk(
      checkified_fun)

  @lu.wrap_init
  def checkified_fwd(*args):
    # TODO(lenamartens, sharadmv): why not checkify here?
    # `args` interleaves each input with a flag (x0, z0, x1, z1, ...),
    # presumably its symbolic-zero indicator. The error inputs come first and
    # are dropped.
    xs, zeros = args[::2], args[1::2]
    xs, zeros = xs[num_errs:], zeros[num_errs:]
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk(*zeros)
    # The forward jaxpr is evaluated without the leading `num_consts` inputs.
    xs_without_consts = xs[num_consts:]
    return core.eval_jaxpr(fwd_jaxpr, fwd_consts, *xs_without_consts)

  # Prepend a None cotangent for every error input.
  bwd_ = lambda *args: (*(None,)*num_errs, *bwd(*args))
  checkified_fwd, fwd_out_tree = flatten_fun_output(checkified_fwd)
  all_outs = custom_derivatives.custom_vjp_call_p.bind(
      checkified_fun, checkified_fwd, bwd_, *err_vals, *in_vals, out_trees=out_trees,
      symbolic_zeros=symbolic_zeros)
  # `fst` is True when the checkified function's auxiliary output (its
  # error/output tree) is the one available.
  fst, out_metadata = lu.merge_linear_aux(fun_metadata, fwd_out_tree)
  if fst:
    err_and_out_tree, _ = out_metadata
    out_err, out_vals = tree_unflatten(err_and_out_tree, all_outs)
  else:
    out_err, out_vals = in_err, all_outs
  return out_err, out_vals
error_checks[custom_derivatives.custom_vjp_call_jaxpr_p] = custom_vjp_call_jaxpr_rule


def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
  """Checkify rule for ``check_p``: merge a ``check``'s error into ``error``.

  Each error effect of the checked error is routed by category. Enabled
  categories are merged into the functional error that is returned (the
  "discharged" error). The rest are collected into a "recharged" error,
  which is currently not re-raised (see the TODO below).

  Returns:
    ``(discharged_error, [])`` -- ``check_p`` has no outputs.
  """
  del debug
  new_error = tree_unflatten(err_tree, args)
  discharged_error, recharged_error = error, init_error
  for error_effect, pred in new_error._pred.items():
    code = new_error._code[error_effect]
    payload = new_error._payload[error_effect]
    enabled = error_effect.error_type in enabled_errors
    if enabled:
      discharged_error = update_error(discharged_error, pred, code, {}, payload,
                                      error_effect)
    else:
      recharged_error = update_error(recharged_error, pred, code, {}, payload,
                                     error_effect)

  merged_metadata = {**new_error._metadata, **discharged_error._metadata}
  discharged_error = discharged_error._replace(_metadata=merged_metadata)
  recharged_error = recharged_error._replace(_metadata=new_error._metadata)
  # TODO(lenamartens): we actually need to recharge, but this would be a
  # breaking API change so leaving for a follow-up.
  # check_error(recharged_error)
  return discharged_error, []
error_checks[check_p] = check_discharge_rule


## checkify public api

# Sets of error categories that can be passed as `errors=` to `checkify`.
# Single-category sets:
user_checks = frozenset({FailedCheckError})
nan_checks = frozenset({NaNError})
index_checks = frozenset({OOBError})
div_checks = frozenset({DivisionByZeroError})
# Combined sets:
float_checks = nan_checks.union(div_checks)
automatic_checks = float_checks.union(index_checks)
all_checks = automatic_checks.union(user_checks)


def checkify(f: Callable[..., Out],
             errors: frozenset[ErrorCategory] = user_checks
             ) -> Callable[..., tuple[Error, Out]]:
  """Functionalize `check` calls in ``f``, and optionally add run-time error checks.

  Run-time errors are either user-added :func:`~check` assertions, or
  automatically added checks like NaN checks, depending on the ``errors``
  argument.

  The returned function will return an Error object ``err`` along with the
  output of the original function. ``err.get()`` will either return ``None``
  (if no error occurred) or a string containing an error message. This error
  message will correspond to the first error which occurred. ``err.throw()``
  will raise a ValueError with the error message if an error occurred.

  By default only user-added :func:`~check` assertions are enabled. You can
  enable automatic checks through the ``errors`` argument.

  The check sets which can be enabled, and when an error is generated:
    - ``user_checks``: a :func:`~check` evaluated to False.
    - ``nan_checks``: a floating-point operation generated a NaN value
      as output.
    - ``div_checks``: a division by zero.
    - ``index_checks``: an index was out-of-bounds.

  Multiple categories can be enabled together by passing a set of error
  categories (e.g. ``errors=nan_checks``). Sets can be combined with set
  operations (e.g. ``errors=float_checks|user_checks``).

  Args:
    f: Callable which can contain user checks (see :func:`~check`).
    errors: A set of ErrorCategory values which defines the set of enabled
      checks. By default only explicit ``checks`` are enabled
      (``user_checks``). You can also for example enable NAN and
      DIV errors by passing the ``float_checks`` set, or for
      example combine multiple sets through set operations
      (``float_checks | user_checks``)
  Returns:
    A function which accepts the same arguments as ``f`` and returns as output
    a pair where the first element is an ``Error`` value, representing the first
    failed :func:`~check`, and the second element is the original output of
    ``f``.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = jnp.sin(x)
    ...   return x+y
    >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: nan generated by primitive: sin
  """
  @traceback_util.api_boundary
  def checked_fun(*args, **kwargs):
    # close over all arguments so they're not turned into abstract values.
    in_tree = jtu.tree_structure(((), {}))
    closed_f = lambda: f(*args, **kwargs)
    # stage: trace the zero-argument closure to a jaxpr; values it captures
    # come back as `consts`, which become the jaxpr's leading inputs.
    fun_, out_tree = flatten_fun(lu.wrap_init(closed_f), in_tree)
    debug = pe.debug_info(closed_f, in_tree, out_tree, False, 'checkify')
    jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(fun_, (), debug)
    jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
    # checkify: evaluate the jaxpr with checks, starting from init_error.
    error, out_flat = checkify_jaxpr(jaxpr, errors, init_error, *consts)
    return error, jtu.tree_unflatten(out_tree(), out_flat)
  return checked_fun

def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Assert that ``pred`` holds, recording an error with ``msg`` otherwise.

  This operation is effectful and cannot be staged (jitted/scanned/...) as
  is: wrap the function containing it with :func:`~checkify` before staging.

  Args:
    pred: scalar boolean; when False, a FailedCheckError is added.
    msg: error message used when the error is added; may be a format string.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
      These may be traced values, which lets the error message include
      run-time values. Tracking these run-time arrays increases memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.check(x>0, "{x} needs to be positive!", x=x)
    ...   return 1/x
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(-3.)
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: -3. needs to be positive!

  """
  is_debug_check = False
  _check(pred, msg, is_debug_check, *fmt_args, **fmt_kwargs)

def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
  """Shared implementation of :func:`check` and :func:`debug_check`.

  Validates ``pred`` and the formatting arguments, then binds ``check_p`` on
  an error that fails when ``pred`` is False. ``debug`` selects the name
  used in the TypeError message and is forwarded to ``check_p``.

  Raises:
    TypeError: if ``pred`` is not a scalar boolean, or if a formatting
      argument leaf is not an array.
  """
  if not is_scalar_pred(pred):
    fn_name = 'debug_check' if debug else 'check'
    raise TypeError(f'{fn_name} takes a scalar pred as argument, got {pred}')
  for leaf in jtu.tree_leaves((fmt_args, fmt_kwargs)):
    if isinstance(leaf, (Array, np.ndarray)):
      continue
    raise TypeError('Formatting arguments to checkify.check need to be '
                    'PyTrees of arrays, but got '
                    f'{leaf!r} of type {type(leaf)}.')
  failure = FailedCheckError(get_traceback(), msg, *fmt_args, **fmt_kwargs)
  failing_error = assert_func(init_error, jnp.logical_not(pred), failure)
  _check_error(failing_error, debug=debug)

def _check_error(error, *, debug=False):
  """Bind ``check_p`` on the flattened leaves of ``error``.

  If any predicate in ``error`` has a non-scalar shape, the error is first
  reduced with ``_reduce_any_error``.
  """
  has_nonscalar_pred = any(np.shape(p) for p in error._pred.values())
  if has_nonscalar_pred:
    error = _reduce_any_error(error)
  leaves, treedef = tree_flatten(error)
  return check_p.bind(*leaves, err_tree=treedef, debug=debug)


def is_scalar_pred(pred) -> bool:
  """Return True for a Python bool or a shape-``()`` boolean jax Array."""
  if isinstance(pred, bool):
    return True
  if not isinstance(pred, Array):
    return False
  return pred.shape == () and pred.dtype == jnp.dtype('bool')


def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
  """Like :func:`check`, but only takes effect under :func:`~checkify`.

  When the function is not transformed by :func:`~checkify`, the check is
  dropped.

  Args:
    pred: scalar boolean; when False, a FailedCheckError is added.
    msg: error message used when the error is added.
    fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
      `msg`, eg.:
      ``debug_check(.., "check failed on values {} and {named}", x, named=y)``
      These may be traced values, which lets the error message include
      run-time values. Tracking these run-time arrays increases memory usage,
      even if no error happens.

  For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> def f(x):
    ...   checkify.debug_check(x!=0, "cannot be zero!")
    ...   return x
    >>> _ = f(0)  # running without checkify means no debug_check is run.
    >>> checked_f = checkify.checkify(f)
    >>> err, out = jax.jit(checked_f)(0)  # running with checkify runs debug_check.
    >>> err.throw()  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    jax._src.checkify.JaxRuntimeError: cannot be zero!

  """
  is_debug_check = True
  _check(pred, msg, is_debug_check, *fmt_args, **fmt_kwargs)


def check_error(error: Error) -> None:
  """Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.

  The semantics of this function are equivalent to:

  >>> def check_error(err: Error) -> None:
  ...   err.throw()  # can raise ValueError

  but, unlike that implementation, ``check_error`` can be functionalized
  using the :func:`~checkify` transformation.

  ``check_error`` is similar to :func:`~check`, but takes an ``Error`` value
  instead of a boolean predicate and a message string. Both raise a Python
  Exception on failure (a side-effect), so neither can be staged out by
  :func:`~jax.jit`, :func:`~jax.pmap`, :func:`~jax.lax.scan`, etc., and both
  can be functionalized with :func:`~checkify`.

  ``check_error`` acts as the inverse of :func:`~checkify`. :func:`~checkify`
  turns a function that can raise a functionalizable Exception into one that
  returns an ``Error`` value instead; ``check_error`` turns an ``Error``
  value back into a functionalizable Exception effect. This is useful for
  turning checks that were functionalized by :func:`~checkify` back into
  Python Exceptions.

  Args:
    error: Error to check.

  Raises:
    TypeError: if ``error`` is not an ``Error``.

  For example, you might want to functionalize part of your program through
  checkify, stage out your functionalized code through :func:`~jax.jit`, then
  re-inject your error value outside of the :func:`~jax.jit`:

  >>> import jax
  >>> from jax.experimental import checkify
  >>> def f(x):
  ...   checkify.check(x>0, "must be positive!")
  ...   return x
  >>> def with_inner_jit(x):
  ...   checked_f = checkify.checkify(f)
  ...   # a checkified function can be jitted
  ...   error, out = jax.jit(checked_f)(x)
  ...   checkify.check_error(error)
  ...   return out
  >>> _ = with_inner_jit(1)  # no failed check
  >>> with_inner_jit(-1)  # doctest: +IGNORE_EXCEPTION_DETAIL
  Traceback (most recent call last):
    ...
  jax._src.checkify.JaxRuntimeError: must be positive!
  >>> # can re-checkify
  >>> error, _ = checkify.checkify(with_inner_jit)(-1)
  """
  if isinstance(error, Error):
    _check_error(error, debug=False)
    return
  raise TypeError('check_error takes an Error as argument, '
                  f'got type {type(error)} instead.')
